[JAX] TE GMM v2 enforcement Env Var #2794
jberchtold-nvidia wants to merge 1 commit into NVIDIA:main
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
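Greptile Summary
This PR adds an environment variable, NVTE_JAX_ENFORCE_V2_GROUPED_GEMM. When it is enabled, _can_use_v2_grouped_gemm raises a RuntimeError instead of returning False whenever the V2 grouped GEMM path cannot be used.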
Confidence Score: 5/5
Important Files Changed
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[_can_use_v2_grouped_gemm called] --> B[read enforce_v2_gmm via cached env-var]
B --> C{_v2_grouped_gemm_available?}
C -- No --> D{enforce_v2_gmm?}
D -- Yes --> E[raise RuntimeError with reason string]
D -- No --> F[return False]
C -- Yes --> G{device SM < 100?}
G -- Yes --> H{enforce_v2_gmm?}
H -- Yes --> I[raise RuntimeError with compute cap]
H -- No --> J[return False]
G -- No --> K{NO_SCALING + BF16 + no bias?}
K -- Yes --> L[return True → use V2 path]
K -- No --> M{enforce_v2_gmm?}
M -- Yes --> N[raise RuntimeError with dtype/bias/mode]
M -- No --> O[return False → use V1 path]
Reviews (1): Last reviewed commit: "TE GMM v2 enforcement env var"
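The decision flow in the flowchart above can be sketched in plain Python. This is a minimal illustration, not the code from the PR: the function name `can_use_v2_grouped_gemm`, its keyword parameters, the string values for scaling mode and dtype, the `"1"` parsing of the environment variable, and the wording of the error messages are all assumptions made so the sketch runs on its own. Only the environment variable name and the order of the checks come from the flowchart and the diff.

```python
import os
from functools import lru_cache

ENV_VAR = "NVTE_JAX_ENFORCE_V2_GROUPED_GEMM"


@lru_cache(maxsize=None)
def enforce_v2_gmm_enabled() -> bool:
    """Read the flag once; later calls reuse the cached value.

    Assumption: the flag is considered enabled when it is set to "1".
    """
    return os.environ.get(ENV_VAR, "0") == "1"


def can_use_v2_grouped_gemm(
    *,
    v2_available: bool,
    compute_capability: int,
    scaling_mode: str,
    dtype: str,
    has_bias: bool,
    enforce: bool,
) -> bool:
    """Hypothetical stand-in for the checks shown in the flowchart."""

    def reject(reason: str) -> bool:
        # With enforcement on, every fallback becomes a hard error.
        if enforce:
            raise RuntimeError(f"{reason} and {ENV_VAR} is enabled.")
        return False

    if not v2_available:
        return reject("The V2 grouped GEMM is not available")
    if compute_capability < 100:
        return reject(
            "The V2 grouped GEMM requires SM100+ but the device compute"
            f" capability is {compute_capability}"
        )
    if scaling_mode == "NO_SCALING" and dtype == "bfloat16" and not has_bias:
        return True  # V2 path
    return reject(
        "The V2 grouped GEMM does not support"
        f" scaling_mode={scaling_mode}, dtype={dtype}, has_bias={has_bias}"
    )
```

A caller would pass `enforce=enforce_v2_gmm_enabled()`; keeping the flag as an explicit argument makes each branch easy to exercise in a unit test without touching the process environment.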
    if get_device_compute_capability(0) < 100:
        if enforce_v2_gmm:
            raise RuntimeError(
                "The TE V2 grouped GEMM requires SM100+ (Blackwell or newer) but current device"
                f" compute capability of GPU 0 is {get_device_compute_capability(0)} and"
                " NVTE_JAX_ENFORCE_V2_GROUPED_GEMM is enabled."
            )
Redundant get_device_compute_capability(0) call
get_device_compute_capability(0) is already called in the if condition on line 1957, and then called a second time inside the error message on line 1961. While the call is likely cheap, it is cleaner and more efficient to capture the result once and reuse it.
Suggested change:

    - if get_device_compute_capability(0) < 100:
    -     if enforce_v2_gmm:
    -         raise RuntimeError(
    -             "The TE V2 grouped GEMM requires SM100+ (Blackwell or newer) but current device"
    -             f" compute capability of GPU 0 is {get_device_compute_capability(0)} and"
    -             " NVTE_JAX_ENFORCE_V2_GROUPED_GEMM is enabled."
    -         )
    + cap = get_device_compute_capability(0)
    + if cap < 100:
    +     if enforce_v2_gmm:
    +         raise RuntimeError(
    +             "The TE V2 grouped GEMM requires SM100+ (Blackwell or newer) but current device"
    +             f" compute capability of GPU 0 is {cap} and"
    +             " NVTE_JAX_ENFORCE_V2_GROUPED_GEMM is enabled."
    +         )
    +     return False
/te-ci

Closing, as this change was merged as part of #2749.